{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "install required pacages"
      ],
      "metadata": {
        "id": "T18eK3oC4Of7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install openai lancedb langchain openai tiktoken sentence_transformers arxiv pymupdf gradio ArxivLoader langchain-community langchain-openai -q"
      ],
      "metadata": {
        "id": "5pUjVpum3luB",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f74fe0c8-9946-4b62-afc5-dd21ed8d6da9"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/52.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m52.0/52.0 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# restart the runtime & now drectly run below code ### dont need to install pacages again"
      ],
      "metadata": {
        "id": "fBVuKpU43XMP"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import langchain\n",
        "from langchain.document_loaders import ArxivLoader\n",
        "from langchain.embeddings import OpenAIEmbeddings\n",
        "from langchain.llms import OpenAI\n",
        "from langchain.chains import FlareChain\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "\n",
        "# Uncomment the below if you want to see all the intermediate steps\n",
        "# langchain.verbose=True"
      ],
      "metadata": {
        "id": "QZgTTVUH6HFU"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain import PromptTemplate, LLMChain\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "from langchain.vectorstores import Chroma\n",
        "from langchain.chains import RetrievalQA\n",
        "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
        "from io import BytesIO\n",
        "from langchain.document_loaders import PyPDFLoader\n",
        "import langchain\n",
        "from langchain.document_loaders import ArxivLoader\n",
        "import gradio as gr\n",
        "import lancedb\n",
        "from langchain.vectorstores import LanceDB\n",
        "from langchain.document_loaders import ArxivLoader\n",
        "from langchain.chains import FlareChain\n",
        "from langchain.prompts import PromptTemplate\n",
        "from langchain.chains import LLMChain\n",
        "import os\n",
        "from langchain.llms import OpenAI\n",
        "import getpass\n",
        "\n",
        "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-......\"\n",
        "\n",
        "llm = OpenAI()\n",
        "\n",
        "model_name = \"BAAI/bge-large-en\"\n",
        "model_kwargs = {\"device\": \"cuda\"}\n",
        "encode_kwargs = {\"normalize_embeddings\": False}\n",
        "embeddings = HuggingFaceBgeEmbeddings(\n",
        "    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs\n",
        ")\n",
        "\n",
        "\n",
        "# instantiate llm\n",
        "\n",
        "# instantiate embeddings model\n",
        "# embeddings = OpenAIEmbeddings()\n",
        "\n",
        "# fetch docs from arxiv, in this case it's the FLARE paper\n",
        "docs = ArxivLoader(query=\"2305.06983\", load_max_docs=2).load()\n",
        "\n",
        "\n",
        "# here is example https://arxiv.org/pdf/2305.06983.pdf\n",
        "# you need to pass this number to query 2305.06983\n",
        "# fetch docs from arxiv, in this case it's the FLARE paper\n",
        "docs = ArxivLoader(query=\"2305.06983\", load_max_docs=1).load()\n",
        "\n",
        "# instantiate text splitter\n",
        "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)\n",
        "\n",
        "# split the document into chunks\n",
        "doc_chunks = text_splitter.split_documents(docs)\n",
        "\n",
        "# lancedb vectordb\n",
        "db = lancedb.connect(\"/tmp/lancedb\")\n",
        "table = db.create_table(\n",
        "    \"documentsai\",\n",
        "    data=[\n",
        "        {\n",
        "            \"vector\": embeddings.embed_query(\"Hello World\"),\n",
        "            \"text\": \"Hello World\",\n",
        "            \"id\": \"1\",\n",
        "        }\n",
        "    ],\n",
        "    mode=\"overwrite\",\n",
        ")\n",
        "vector_store = LanceDB.from_documents(doc_chunks, embeddings)\n",
        "\n",
        "vector_store_retriever = vector_store.as_retriever()\n",
        "\n",
        "flare = FlareChain.from_llm(\n",
        "    llm=llm, retriever=vector_store_retriever, max_generation_len=300, min_prob=0.45\n",
        ")\n",
        "\n",
        "\n",
        "# Define a function to generate FLARE output based on user input\n",
        "def generate_flare_output(input_text):\n",
        "    output = flare.run(input_text)\n",
        "    return output\n",
        "\n",
        "\n",
        "input = gr.Text(\n",
        "    label=\"Prompt\",\n",
        "    show_label=False,\n",
        "    max_lines=1,\n",
        "    placeholder=\"Enter your prompt\",\n",
        "    container=False,\n",
        ")\n",
        "\n",
        "iface = gr.Interface(\n",
        "    fn=generate_flare_output,\n",
        "    inputs=input,\n",
        "    outputs=\"text\",\n",
        "    title=\"My AI bot\",\n",
        "    description=\"FLARE implementation with lancedb & bge embedding.\",\n",
        ")\n",
        "\n",
        "iface.launch(debug=True, share=True)"
      ],
      "metadata": {
        "id": "gtPkyqdb1Jv6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 650
        },
        "outputId": "cfce1052-6526-4f72-a192-4336b6a09512"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
            "Running on public URL: https://095bfeb8a6f71d44e2.gradio.live\n",
            "\n",
            "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<div><iframe src=\"https://095bfeb8a6f71d44e2.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-10-af386cbd1db6>:77: LangChainDeprecationWarning: The method `Chain.run` was deprecated in langchain 0.1.0 and will be removed in 1.0. Use invoke instead.\n",
            "  output = flare.run(input_text)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "AuX5eYG2LA-y"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "toc_visible": true,
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}